{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "本文涉及的jupter notebook在[篇章4代码库中](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1)。\n",
        "\n",
        "建议直接使用google colab notebook打开本教程，可以快速下载相关数据集和模型。\n",
        "如果您正在google的colab中打开这个notebook，您可能需要安装Transformers和🤗Datasets库。将以下命令取消注释即可安装。"
      ],
      "metadata": {
        "id": "X4cRE8IbIrIV"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "! pip install datasets transformers \"sacrebleu>=1.4.12,<2.0.0\" sentencepiece"
      ],
      "outputs": [],
      "metadata": {
        "id": "MOsHUjgdIrIW"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "如果您正在本地打开这个notebook，请确保您认真阅读并安装了transformer-quick-start-zh的readme文件中的所有依赖库。您也可以在[这里](https://github.com/huggingface/transformers/tree/master/examples/seq2seq)找到本notebook的多GPU分布式训练版本。"
      ],
      "metadata": {
        "id": "HFASsisvIrIb"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 微调transformer模型解决翻译任务"
      ],
      "metadata": {
        "id": "rEJBSTyZIrIb"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "在这个notebook中，我们将展示如何使用[🤗 Transformers](https://github.com/huggingface/transformers)代码库中的模型来解决自然语言处理中的翻译任务。我们将会使用[WMT dataset](http://www.statmt.org/wmt16/)数据集。这是翻译任务最常用的数据集之一。\n",
        "\n",
        "下面展示了一个例子：\n",
        "\n",
        "![Widget inference on a translation task](https://github.com/huggingface/notebooks/blob/master/examples/images/translation.png?raw=1)\n",
        "\n",
        "对于翻译任务，我们将展示如何使用简单的加载数据集，同时针对相应的仍无使用transformer中的Trainer接口对模型进行微调。"
      ],
      "metadata": {
        "id": "kTCFado4IrIc"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "source": [
        "model_checkpoint = \"Helsinki-NLP/opus-mt-en-ro\" \n",
        "# 选择一个模型checkpoint"
      ],
      "outputs": [],
      "metadata": {
        "id": "rJvQjiUqPjhM"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "只要预训练的transformer模型包含seq2seq结构的head层，那么本notebook理论上可以使用各种各样的transformer模型[模型面板](https://huggingface.co/models)，解决任何翻译任务。\n",
        "\n",
        "本文我们使用已经训练好的[`Helsinki-NLP/opus-mt-en-ro`](https://huggingface.co/Helsinki-NLP/opus-mt-en-ro) checkpoint来做翻译任务。 "
      ],
      "metadata": {
        "id": "4RRkXuteIrIh"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 加载数据"
      ],
      "metadata": {
        "id": "whPRbBNbIrIl"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "我们将会使用🤗 Datasets库来加载数据和对应的评测方式。数据加载和评测方式加载只需要简单使用load_dataset和load_metric即可。我们使用WMT数据集中的English/Romanian双语翻译。\n"
      ],
      "metadata": {
        "id": "W7QYTpxXIrIl"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "source": [
        "from datasets import load_dataset, load_metric\n",
        "\n",
        "raw_datasets = load_dataset(\"wmt16\", \"ro-en\")\n",
        "metric = load_metric(\"sacrebleu\")"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading: 2.81kB [00:00, 523kB/s]                    \n",
            "Downloading: 3.19kB [00:00, 758kB/s]                    \n",
            "Downloading: 41.0kB [00:00, 11.0MB/s]                   \n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading and preparing dataset wmt16/ro-en (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /Users/niepig/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading: 100%|██████████| 225M/225M [00:18<00:00, 12.2MB/s]\n",
            "Downloading: 100%|██████████| 23.5M/23.5M [00:16<00:00, 1.44MB/s]\n",
            "Downloading: 100%|██████████| 38.7M/38.7M [00:03<00:00, 9.82MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset wmt16 downloaded and prepared to /Users/niepig/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a. Subsequent calls will reuse this data.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading: 5.40kB [00:00, 2.08MB/s]                   \n"
          ]
        }
      ],
      "metadata": {
        "id": "IreSlFmlIrIm"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "这个datasets对象本身是一种[`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict)数据结构. 对于训练集、验证集和测试集，只需要使用对应的key（train，validation，test）即可得到相应的数据。"
      ],
      "metadata": {
        "id": "RzfPtOMoIrIu"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "source": [
        "raw_datasets"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['translation'],\n",
              "        num_rows: 610320\n",
              "    })\n",
              "    validation: Dataset({\n",
              "        features: ['translation'],\n",
              "        num_rows: 1999\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['translation'],\n",
              "        num_rows: 1999\n",
              "    })\n",
              "})"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GWiVUF0jIrIv",
        "outputId": "3151a9fc-7239-4471-a8f0-548dd68d5a89"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "给定一个数据切分的key（train、validation或者test）和下标即可查看数据。"
      ],
      "metadata": {
        "id": "u3EtYfeHIrIz"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "source": [
        "raw_datasets[\"train\"][0]\n",
        "# 我们可以看到一句英语en对应一句罗马尼亚语言ro"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'translation': {'en': 'Membership of Parliament: see Minutes',\n",
              "  'ro': 'Componenţa Parlamentului: a se vedea procesul-verbal'}}"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X6HrpprwIrIz",
        "outputId": "69f3873e-2d1f-4614-e43e-9e654277245c"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "为了能够进一步理解数据长什么样子，下面的函数将从数据集里随机选择几个例子进行展示。"
      ],
      "metadata": {
        "id": "WHUmphG3IrI3"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "source": [
        "import datasets\n",
        "import random\n",
        "import pandas as pd\n",
        "from IPython.display import display, HTML\n",
        "\n",
        "def show_random_elements(dataset, num_examples=5):\n",
        "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
        "    picks = []\n",
        "    for _ in range(num_examples):\n",
        "        pick = random.randint(0, len(dataset)-1)\n",
        "        while pick in picks:\n",
        "            pick = random.randint(0, len(dataset)-1)\n",
        "        picks.append(pick)\n",
        "    \n",
        "    df = pd.DataFrame(dataset[picks])\n",
        "    for column, typ in dataset.features.items():\n",
        "        if isinstance(typ, datasets.ClassLabel):\n",
        "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
        "    display(HTML(df.to_html()))"
      ],
      "outputs": [],
      "metadata": {
        "id": "i3j8APAoIrI3"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "source": [
        "show_random_elements(raw_datasets[\"train\"])"
      ],
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>translation</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>{'en': 'I do not believe that this is the right course.', 'ro': 'Nu cred că acesta este varianta corectă.'}</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>{'en': 'A total of 104 new jobs were created at the European Chemicals Agency, which mainly supervises our REACH projects.', 'ro': 'Un total de 104 noi locuri de muncă au fost create la Agenția Europeană pentru Produse Chimice, care, în special, supraveghează proiectele noastre REACH.'}</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>{'en': 'In view of the above, will the Council say what stage discussions for Turkish participation in joint Frontex operations have reached?', 'ro': 'Care este stadiul negocierilor referitoare la participarea Turciei la operațiunile comune din cadrul Frontex?'}</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>{'en': 'We now fear that if the scope of this directive is expanded, the directive will suffer exactly the same fate as the last attempt at introducing 'Made in' origin marking - in other words, that it will once again be blocked by the Council.', 'ro': 'Acum ne temem că, dacă sfera de aplicare a directivei va fi extinsă, aceasta va avea exact aceeaşi soartă ca ultima încercare de introducere a marcajului de origine \"Made in”, cu alte cuvinte, că va fi din nou blocată la Consiliu.'}</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>{'en': 'The country dropped nine slots to 85th, with a score of 6.58.', 'ro': 'Ţara a coborât nouă poziţii, pe locul 85, cu un scor de 6,58.'}</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ]
          },
          "metadata": {}
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 255
        },
        "id": "SZy5tRB_IrI7",
        "outputId": "93e16172-d927-457d-fcab-04dcb4d2ef29"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "metric是[`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric)类的一个实例，查看metric和使用的例子:"
      ],
      "metadata": {
        "id": "lnjDIuQ3IrI-"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "source": [
        "metric"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Metric(name: \"sacrebleu\", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}, usage: \"\"\"\n",
              "Produces BLEU scores along with its sufficient statistics\n",
              "from a source against one or more references.\n",
              "\n",
              "Args:\n",
              "    predictions: The system stream (a sequence of segments)\n",
              "    references: A list of one or more reference streams (each a sequence of segments)\n",
              "    smooth: The smoothing method to use\n",
              "    smooth_value: For 'floor' smoothing, the floor to use\n",
              "    force: Ignore data that looks already tokenized\n",
              "    lowercase: Lowercase the data\n",
              "    tokenize: The tokenizer to use\n",
              "Returns:\n",
              "    'score': BLEU score,\n",
              "    'counts': Counts,\n",
              "    'totals': Totals,\n",
              "    'precisions': Precisions,\n",
              "    'bp': Brevity penalty,\n",
              "    'sys_len': predictions length,\n",
              "    'ref_len': reference length,\n",
              "Examples:\n",
              "\n",
              "    >>> predictions = [\"hello there general kenobi\", \"foo bar foobar\"]\n",
              "    >>> references = [[\"hello there general kenobi\", \"hello there !\"], [\"foo bar foobar\", \"foo bar foobar\"]]\n",
              "    >>> sacrebleu = datasets.load_metric(\"sacrebleu\")\n",
              "    >>> results = sacrebleu.compute(predictions=predictions, references=references)\n",
              "    >>> print(list(results.keys()))\n",
              "    ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len']\n",
              "    >>> print(round(results[\"score\"], 1))\n",
              "    100.0\n",
              "\"\"\", stored examples: 0)"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5o4rUteaIrI_",
        "outputId": "4814f907-6225-4af0-ee63-376699dc79ee"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "我们使用`compute`方法来对比predictions和labels，从而计算得分。predictions和labels都需要是一个list。具体格式见下面的例子："
      ],
      "metadata": {
        "id": "jAWdqcUBIrJC"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "source": [
        "fake_preds = [\"hello there\", \"general kenobi\"]\n",
        "fake_labels = [[\"hello there\"], [\"general kenobi\"]]\n",
        "metric.compute(predictions=fake_preds, references=fake_labels)"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'score': 0.0,\n",
              " 'counts': [4, 2, 0, 0],\n",
              " 'totals': [4, 2, 0, 0],\n",
              " 'precisions': [100.0, 100.0, 0.0, 0.0],\n",
              " 'bp': 1.0,\n",
              " 'sys_len': 4,\n",
              " 'ref_len': 4}"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6XN1Rq0aIrJC",
        "outputId": "d130ad50-c6ca-42bc-8b14-31021feb620d"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 数据预处理"
      ],
      "metadata": {
        "id": "n9qywopnIrJH"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "在将数据喂入模型之前，我们需要对数据进行预处理。预处理的工具叫Tokenizer。Tokenizer首先对输入进行tokenize，然后将tokens转化为预模型中需要对应的token ID，再转化为模型需要的输入格式。\n",
        "\n",
        "为了达到数据预处理的目的，我们使用AutoTokenizer.from_pretrained方法实例化我们的tokenizer，这样可以确保：\n",
        "\n",
        "- 我们得到一个与预训练模型一一对应的tokenizer。\n",
        "- 使用指定的模型checkpoint对应的tokenizer的时候，我们也下载了模型需要的词表库vocabulary，准确来说是tokens vocabulary。\n",
        "\n",
        "\n",
        "这个被下载的tokens vocabulary会被缓存起来，从而再次使用的时候不会重新下载。"
      ],
      "metadata": {
        "id": "YVx71GdAIrJH"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "source": [
        "from transformers import AutoTokenizer\n",
        "# 需要安装`sentencepiece`： pip install sentencepiece\n",
        "    \n",
        "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading: 100%|██████████| 1.13k/1.13k [00:00<00:00, 466kB/s]\n",
            "Downloading: 100%|██████████| 789k/789k [00:00<00:00, 882kB/s]\n",
            "Downloading: 100%|██████████| 817k/817k [00:00<00:00, 902kB/s]\n",
            "Downloading: 100%|██████████| 1.39M/1.39M [00:01<00:00, 1.24MB/s]\n",
            "Downloading: 100%|██████████| 42.0/42.0 [00:00<00:00, 14.6kB/s]\n"
          ]
        }
      ],
      "metadata": {
        "id": "eXNLu_-nIrJI"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "以我们使用的mBART模型为例，我们需要正确设置source语言和target语言。如果您要翻译的是其他双语语料，请查看[这里](https://huggingface.co/facebook/mbart-large-cc25)。我们可以检查source和target语言的设置：\n"
      ],
      "metadata": {
        "id": "GLRyc5J9PjhS"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "source": [
        "if \"mbart\" in model_checkpoint:\n",
        "    tokenizer.src_lang = \"en-XX\"\n",
        "    tokenizer.tgt_lang = \"ro-RO\""
      ],
      "outputs": [],
      "metadata": {
        "id": "kmXG36baPjhS"
      }
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "Vl6IidfdIrJK"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "tokenizer既可以对单个文本进行预处理，也可以对一对文本进行预处理，tokenizer预处理后得到的数据满足预训练模型输入格式"
      ],
      "metadata": {
        "id": "rowT4iCLIrJK"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "source": [
        "tokenizer(\"Hello, this one sentence!\")"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input_ids': [125, 778, 3, 63, 141, 9191, 23, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}"
            ]
          },
          "metadata": {},
          "execution_count": 12
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a5hBlsrHIrJL",
        "outputId": "072ee20c-db1d-4ba1-a98a-119405ea9552"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "上面看到的token IDs也就是input_ids一般来说随着预训练模型名字的不同而有所不同。原因是不同的预训练模型在预训练的时候设定了不同的规则。但只要tokenizer和model的名字一致，那么tokenizer预处理的输入格式就会满足model需求的。关于预处理更多内容参考[这个教程](https://huggingface.co/transformers/preprocessing.html)\n",
        "\n",
        "除了可以tokenize一句话，我们也可以tokenize一个list的句子。"
      ],
      "metadata": {
        "id": "qo_0B1M2IrJM"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "source": [
        "tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"])"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 0], [187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LkLffVlKPjhT",
        "outputId": "f144d050-fc84-4a1a-9fc2-25281b681441"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "注意：为了给模型准备好翻译的targets，我们使用`as_target_tokenizer`来控制targets所对应的特殊token："
      ],
      "metadata": {
        "id": "-uVqYJrePjhT"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "source": [
        "with tokenizer.as_target_tokenizer():\n",
        "    print(tokenizer(\"Hello, this one sentence!\"))\n",
        "    model_input = tokenizer(\"Hello, this one sentence!\")\n",
        "    tokens = tokenizer.convert_ids_to_tokens(model_input['input_ids'])\n",
        "    # 打印看一下special toke\n",
        "    print('tokens: {}'.format(tokens))"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'input_ids': [10334, 1204, 3, 15, 8915, 27, 452, 59, 29579, 581, 23, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n",
            "tokens: ['▁Hel', 'lo', ',', '▁', 'this', '▁o', 'ne', '▁se', 'nten', 'ce', '!', '</s>']\n"
          ]
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DgCW0X0FPjhT",
        "outputId": "352c44ab-f025-4cf6-98d1-786f6f07111a"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "如果您使用的是T5预训练模型的checkpoints，需要对特殊的前缀进行检查。T5使用特殊的前缀来告诉模型具体要做的任务，具体前缀例子如下：\n"
      ],
      "metadata": {
        "id": "2C0hcmp9IrJQ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "source": [
        "if model_checkpoint in [\"t5-small\", \"t5-base\", \"t5-larg\", \"t5-3b\", \"t5-11b\"]:\n",
        "    prefix = \"translate English to Romanian: \"\n",
        "else:\n",
        "    prefix = \"\""
      ],
      "outputs": [],
      "metadata": {
        "id": "xS1JJSdmPjhU"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "现在我们可以把所有内容放在一起组成我们的预处理函数了。我们对样本进行预处理的时候，我们还会`truncation=True`这个参数来确保我们超长的句子被截断。默认情况下，对与比较短的句子我们会自动padding。"
      ],
      "metadata": {
        "id": "CezpZ8gFPjhU"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "source": [
        "max_input_length = 128\n",
        "max_target_length = 128\n",
        "source_lang = \"en\"\n",
        "target_lang = \"ro\"\n",
        "\n",
        "def preprocess_function(examples):\n",
        "    inputs = [prefix + ex[source_lang] for ex in examples[\"translation\"]]\n",
        "    targets = [ex[target_lang] for ex in examples[\"translation\"]]\n",
        "    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n",
        "\n",
        "    # Setup the tokenizer for targets\n",
        "    with tokenizer.as_target_tokenizer():\n",
        "        labels = tokenizer(targets, max_length=max_target_length, truncation=True)\n",
        "\n",
        "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
        "    return model_inputs"
      ],
      "outputs": [],
      "metadata": {
        "id": "vc0BSBLIIrJQ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "以上的预处理函数可以处理一个样本，也可以处理多个样本exapmles。如果是处理多个样本，则返回的是多个样本被预处理之后的结果list。"
      ],
      "metadata": {
        "id": "0lm8ozrJIrJR"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "source": [
        "preprocess_function(raw_datasets['train'][:2])"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input_ids': [[393, 4462, 14, 1137, 53, 216, 28636, 0], [24385, 14, 28636, 14, 4646, 4622, 53, 216, 28636, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[42140, 494, 1750, 53, 8, 59, 903, 3543, 9, 15202, 0], [36199, 6612, 9, 15202, 122, 568, 35788, 21549, 53, 8, 59, 903, 3543, 9, 15202, 0]]}"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-b70jh26IrJS",
        "outputId": "89b26088-d2d2-4312-81d8-b0f5e62dd6a2"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "接下来对数据集datasets里面的所有样本进行预处理，处理的方式是使用map函数，将预处理函数prepare_train_features应用到（map)所有样本上。"
      ],
      "metadata": {
        "id": "zS-6iXTkIrJT"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "source": [
        "tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 611/611 [02:32<00:00,  3.99ba/s]\n",
            "100%|██████████| 2/2 [00:00<00:00,  3.76ba/s]\n",
            "100%|██████████| 2/2 [00:00<00:00,  3.89ba/s]\n"
          ]
        }
      ],
      "metadata": {
        "id": "DDtsaJeVIrJT"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "更好的是，返回的结果会自动被缓存，避免下次处理的时候重新计算（但是也要注意，如果输入有改动，可能会被缓存影响！）。datasets库函数会对输入的参数进行检测，判断是否有变化，如果没有变化就使用缓存数据，如果有变化就重新处理。但如果输入参数不变，想改变输入的时候，最好清理调这个缓存。清理的方式是使用`load_from_cache_file=False`参数。另外，上面使用到的`batched=True`这个参数是tokenizer的特点，以为这会使用多线程同时并行对输入进行处理。"
      ],
      "metadata": {
        "id": "voWiw8C7IrJV"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 微调transformer模型"
      ],
      "metadata": {
        "id": "545PP3o8IrJV"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "既然数据已经准备好了，现在我们需要下载并加载我们的预训练模型，然后微调预训练模型。既然我们是做seq2seq任务，那么我们需要一个能解决这个任务的模型类。我们使用`AutoModelForSeq2SeqLM`这个类。和tokenizer相似，`from_pretrained`方法同样可以帮助我们下载并加载模型，同时也会对模型进行缓存，就不会重复下载模型啦。"
      ],
      "metadata": {
        "id": "FBiW8UpKIrJW"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "source": [
        "from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
        "\n",
        "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading: 100%|██████████| 301M/301M [00:19<00:00, 15.1MB/s]\n"
          ]
        }
      ],
      "metadata": {
        "id": "TlqNaB8jIrJW"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "由于我们微调的任务是机器翻译，而我们加载的是预训练的seq2seq模型，所以不会提示我们加载模型的时候扔掉了一些不匹配的神经网络参数（比如：预训练语言模型的神经网络head被扔掉了，同时随机初始化了机器翻译的神经网络head）。"
      ],
      "metadata": {
        "id": "CczA5lJlIrJX"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "为了能够得到一个`Seq2SeqTrainer`训练工具，我们还需要3个要素，其中最重要的是训练的设定/参数[`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments)。这个训练设定包含了能够定义训练过程的所有属性"
      ],
      "metadata": {
        "id": "_N8urzhyIrJY"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "source": [
        "batch_size = 16\n",
        "args = Seq2SeqTrainingArguments(\n",
        "    \"test-translation\",\n",
        "    evaluation_strategy = \"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    per_device_train_batch_size=batch_size,\n",
        "    per_device_eval_batch_size=batch_size,\n",
        "    weight_decay=0.01,\n",
        "    save_total_limit=3,\n",
        "    num_train_epochs=1,\n",
        "    predict_with_generate=True,\n",
        "    fp16=False,\n",
        ")"
      ],
      "outputs": [],
      "metadata": {
        "id": "Bliy8zgjIrJY"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "上面evaluation_strategy = \"epoch\"参数告诉训练代码：我们每个epcoh会做一次验证评估。\n",
        "\n",
        "上面batch_size在这个notebook之前定义好了。\n",
        "\n",
        "由于我们的数据集比较大，同时`Seq2SeqTrainer`会不断保存模型，所以我们需要告诉它至多保存`save_total_limit=3`个模型。\n",
        "\n",
        "最后我们需要一个数据收集器data collator，将我们处理好的输入喂给模型。"
      ],
      "metadata": {
        "id": "km3pGVdTIrJc"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "source": [
        "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
      ],
      "outputs": [],
      "metadata": {
        "id": "ZMdgZaOoPjhX"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "设置好`Seq2SeqTrainer`还剩最后一件事情，那就是我们需要定义好评估方法。我们使用`metric`来完成评估。将模型预测送入评估之前，我们也会做一些数据后处理："
      ],
      "metadata": {
        "id": "7sZOdRlRIrJd"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "source": [
        "import numpy as np\n",
        "\n",
        "def postprocess_text(preds, labels):\n",
        "    preds = [pred.strip() for pred in preds]\n",
        "    labels = [[label.strip()] for label in labels]\n",
        "\n",
        "    return preds, labels\n",
        "\n",
        "def compute_metrics(eval_preds):\n",
        "    preds, labels = eval_preds\n",
        "    if isinstance(preds, tuple):\n",
        "        preds = preds[0]\n",
        "    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
        "\n",
        "    # Replace -100 in the labels as we can't decode them.\n",
        "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
        "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
        "\n",
        "    # Some simple post-processing\n",
        "    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n",
        "\n",
        "    result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
        "    result = {\"bleu\": result[\"score\"]}\n",
        "\n",
        "    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]\n",
        "    result[\"gen_len\"] = np.mean(prediction_lens)\n",
        "    result = {k: round(v, 4) for k, v in result.items()}\n",
        "    return result"
      ],
      "outputs": [],
      "metadata": {
        "id": "UmvbnJ9JIrJd"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "最后将所有的参数/数据/模型传给`Seq2SeqTrainer`即可"
      ],
      "metadata": {
        "id": "rXuFTAzDIrJe"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "source": [
        "trainer = Seq2SeqTrainer(\n",
        "    model,\n",
        "    args,\n",
        "    train_dataset=tokenized_datasets[\"train\"],\n",
        "    eval_dataset=tokenized_datasets[\"validation\"],\n",
        "    data_collator=data_collator,\n",
        "    tokenizer=tokenizer,\n",
        "    compute_metrics=compute_metrics\n",
        ")"
      ],
      "outputs": [],
      "metadata": {
        "id": "imY1oC3SIrJf"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "调用`train`方法进行微调训练。"
      ],
      "metadata": {
        "id": "CdzABDVcIrJg"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [
        "trainer.train()"
      ],
      "outputs": [],
      "metadata": {
        "id": "uNx5pyRlIrJh",
        "scrolled": false
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "最后别忘了，查看如何上传模型 ，上传模型到](https://huggingface.co/transformers/model_sharing.html) 到[🤗 Model Hub](https://huggingface.co/models)。随后您就可以像这个notebook一开始一样，直接用模型名字就能使用您的模型啦。\n"
      ],
      "metadata": {
        "id": "JXOyGJtqPjhZ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "source": [],
      "outputs": [],
      "metadata": {
        "id": "Jeq1Cq2yPjhZ"
      }
    }
  ],
  "metadata": {
    "colab": {
      "name": "4.6-生成任务-机器翻译",
      "provenance": []
    },
    "interpreter": {
      "hash": "3bfce0b4c492a35815b5705a19fe374a7eea0baaa08b34d90450caf1fe9ce20b"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)"
    },
    "language_info": {
      "name": "python",
      "version": "3.8.10",
      "mimetype": "text/x-python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}